package deeplearn;

import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.StackVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.util.Map;

/**
 * Splits a GAN into two full network copies: one trained as the generator,
 * the other trained as the discriminator. After each training step the
 * relevant parameters are copied between the two copies so they stay in sync.
 */
public class WGAN_GP_C {

	// Lazily-created Swing window used to preview generated digits.
	private static JFrame frame;
	private static JPanel panel;

	static double lr = 0.01;
	static String gmodel = "E:/face/gwgan.zip";
	// FIX: originally this was also "gwgan.zip", so saving the discriminator
	// overwrote the generator checkpoint (and vice versa). Use a distinct file.
	static String dmodel = "E:/face/dwgan.zip";

	/**
	 * Builds the shared network configuration used for BOTH copies of the GAN:
	 * a generator (g1-g3: 10 -> 128 -> 512 -> 784) fed from "input1" (noise z),
	 * stacked with real images from "input2" via a StackVertex, followed by a
	 * discriminator (d1-d3 + sigmoid "out") that classifies real vs. fake.
	 */
	private static GraphBuilder buildGraphConfig() {
		return new NeuralNetConfiguration.Builder().updater(new Sgd(lr))
				.weightInit(WeightInit.XAVIER).graphBuilder().backpropType(BackpropType.Standard)
				.addInputs("input1", "input2")
				// Generator: noise z (batch x 10) -> 28*28 "fake image" activations.
				.addLayer("g1",
						new DenseLayer.Builder().nIn(10).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"input1")
				.addLayer("g2",
						new DenseLayer.Builder().nIn(128).nOut(512).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"g1")
				.addLayer("g3",
						new DenseLayer.Builder().nIn(512).nOut(28 * 28).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"g2")
				// Stack real examples (input2) on top of generated ones (g3) along
				// the batch dimension so the discriminator sees both at once.
				.addVertex("stack", new StackVertex(), "input2", "g3")
				.addLayer("d1",
						new DenseLayer.Builder().nIn(28 * 28).nOut(256).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"stack")
				.addLayer("d2",
						new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"d1")
				.addLayer("d3",
						new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU)
								.weightInit(WeightInit.XAVIER).build(),
						"d2")
				.addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(128).nOut(1)
						.activation(Activation.SIGMOID).build(), "d3")
				.setOutputs("out");
	}

	/**
	 * Training entry point: alternates discriminator and generator steps on
	 * MNIST, copying parameters between the two network copies after each
	 * step, previewing samples every 10 iterations and checkpointing every 100.
	 */
	public static void main(String[] args) throws Exception {
		Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000);

		// Two identical graphs: gnet is trained as the generator,
		// dnet as the discriminator (the two configs were previously duplicated inline).
		ComputationGraph gnet = new ComputationGraph(buildGraphConfig().build());
		ComputationGraph dnet = new ComputationGraph(buildGraphConfig().build());

		gnet.init();
		dnet.init();
		System.out.println(gnet.summary());
		System.out.println(dnet.summary());

		UIServer uiServer = UIServer.getInstance();
		StatsStorage statsStorage = new InMemoryStatsStorage();
		uiServer.attach(statsStorage);
		gnet.setListeners(new ScoreIterationListener(100));
		dnet.setListeners(new ScoreIterationListener(100));

		DataSetIterator train = new MnistDataSetIterator(30, true, 12345);
		// Discriminator labels: the StackVertex puts the 30 real images first
		// (label 1), then the 30 generated ones (label 0). vstack stacks row-wise.
		INDArray labelD = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1));
		// Generator labels: all ones — the generator tries to make the
		// discriminator call every sample "real".
		INDArray labelG = Nd4j.ones(60, 1);

		for (int i = 1; i <= 100000; i++) {
			if (!train.hasNext()) {
				train.reset();
			}
			INDArray trueExp = train.next().getFeatures();

			// --- Discriminator step: fresh noise + real batch, D layers trainable.
			INDArray z = Nd4j.rand(new NormalDistribution(), new long[] { 30, 10 });
			MultiDataSet dataSetD = new MultiDataSet(new INDArray[] { z, trueExp },
					new INDArray[] { labelD });
			trainD(dnet, dataSetD);
			updateD2G(dnet, gnet);

			// --- Generator step: fresh noise, G layers trainable, D frozen.
			z = Nd4j.rand(new NormalDistribution(), new long[] { 30, 10 });
			MultiDataSet dataSetG = new MultiDataSet(new INDArray[] { z, trueExp },
					new INDArray[] { labelG });
			trainG(gnet, dataSetG);
			updateG2D(dnet, gnet);

			if (i % 10 == 0) {
				// Forward a noise batch through the generator and preview the
				// "g3" activations (the fake images) in the Swing grid.
				Map<String, INDArray> map = gnet.feedForward(
						new INDArray[] { Nd4j.rand(new NormalDistribution(), new long[] { 50, 10 }), trueExp }, false);
				INDArray fake = map.get("g3");
				// FIX: the array was previously sized to fake.size(0) (50) but only
				// 8 entries were filled, passing 42 nulls to visualize(). The panel
				// grid is 2x4, so 8 samples is all it can show.
				int count = (int) Math.min(8, fake.size(0));
				INDArray[] samples = new INDArray[count];
				for (int j = 0; j < count; j++) {
					samples[j] = fake.getRow(j);
				}
				visualize(samples);
			}

			if (i % 100 == 0) {
				gnet.save(new File(gmodel), true);
				dnet.save(new File(dmodel), true);
			}
		}
	}

	/**
	 * Copies the discriminator parameters (d1-d3 and out) from the
	 * discriminator copy into the generator copy, so the next generator step
	 * back-propagates through the freshly-trained discriminator.
	 */
	private static void updateD2G(ComputationGraph ganD, ComputationGraph ganG) {
		// FIX: "out" previously had only its bias copied, never its weights,
		// so the generator trained against a stale output layer (trainG keeps
		// it frozen at lr=0, so nothing else ever updated it).
		for (String layer : new String[] { "d1", "d2", "d3", "out" }) {
			ganG.getLayer(layer).setParam("W", ganD.getLayer(layer).getParam("W"));
			ganG.getLayer(layer).setParam("b", ganD.getLayer(layer).getParam("b"));
		}
	}

	/**
	 * Copies the generator parameters (g1-g3) from the generator copy into
	 * the discriminator copy, so the next discriminator step sees samples
	 * from the freshly-trained generator.
	 */
	private static void updateG2D(ComputationGraph ganD, ComputationGraph ganG) {
		for (String layer : new String[] { "g1", "g2", "g3" }) {
			ganD.getLayer(layer).setParam("W", ganG.getLayer(layer).getParam("W"));
			ganD.getLayer(layer).setParam("b", ganG.getLayer(layer).getParam("b"));
		}
	}

	/**
	 * Discriminator training step D(x): freezes the generator layers by
	 * setting their learning rate to 0 and fits one batch.
	 */
	public static void trainD(ComputationGraph net, MultiDataSet dataSet) {
		net.setLearningRate("g1", 0);
		net.setLearningRate("g2", 0);
		net.setLearningRate("g3", 0);
		net.setLearningRate("d1", lr);
		net.setLearningRate("d2", lr);
		net.setLearningRate("d3", lr);
		net.setLearningRate("out", lr);
		net.fit(dataSet);
	}

	/**
	 * Generator training step G(z): freezes the discriminator layers by
	 * setting their learning rate to 0 and fits one batch.
	 */
	public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
		net.setLearningRate("g1", lr);
		net.setLearningRate("g2", lr);
		net.setLearningRate("g3", lr);
		net.setLearningRate("d1", 0);
		net.setLearningRate("d2", 0);
		net.setLearningRate("d3", 0);
		net.setLearningRate("out", 0);
		net.fit(dataSet);
	}

	/**
	 * Renders the given 784-element sample rows into a 2x4 grid window,
	 * creating the window on first use. Null entries are skipped.
	 */
	private static void visualize(INDArray[] samples) {
		if (frame == null) {
			frame = new JFrame();
			frame.setTitle("Viz");
			frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
			frame.setLayout(new BorderLayout());

			panel = new JPanel();
			panel.setLayout(new GridLayout(2, 4, 8, 8));
			frame.add(panel, BorderLayout.CENTER);
			frame.setVisible(true);
		}

		panel.removeAll();

		for (INDArray sample : samples) {
			if (sample == null || sample.size(0) == 0) {
				continue;
			}
			panel.add(getImage(sample));
		}

		frame.revalidate();
		frame.pack();
	}

	/**
	 * Converts a flat 784-element tensor (values assumed in [0,1] — sigmoid/
	 * ReLU activations; TODO confirm range) into a 9x-scaled grayscale JLabel.
	 */
	private static JLabel getImage(INDArray tensor) {
		BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
		for (int i = 0; i < 784; i++) {
			bi.getRaster().setSample(i % 28, i / 28, 0, (int) (255 * tensor.getDouble(i)));
		}
		ImageIcon orig = new ImageIcon(bi);
		Image imageScaled = orig.getImage().getScaledInstance((int) (9 * 28), (int) (9 * 28),
				Image.SCALE_DEFAULT);
		ImageIcon scaled = new ImageIcon(imageScaled);

		return new JLabel(scaled);
	}

	/**
	 * Converts a flat 784-element tensor into a 28x28 grayscale image,
	 * mapping values from the assumed [-1, 1] range to [0, 255] with clamping.
	 * (Currently unused by main; kept for tanh-activated generators.)
	 */
	private static BufferedImage imageFromINDArray(INDArray array) {
		array = array.reshape(28, 28);
		long[] shape = array.shape();
		int height = (int) shape[0];
		int width = (int) shape[1];
		BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
		for (int x = 0; x < width; x++) {
			for (int y = 0; y < height; y++) {
				// FIX: the array was just reshaped to rank-2 (28x28) but was
				// indexed with four indices getDouble(0, 0, y, x); use (y, x).
				// Also removed the per-pixel debug println.
				int gray = (int) ((array.getDouble(y, x) + 1) * 127.5);

				// Clamp out-of-range pixel values.
				gray = Math.min(gray, 255);
				gray = Math.max(gray, 0);

				image.getRaster().setSample(x, y, 0, gray);
			}
		}
		return image;
	}
}